
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import traceback


from ..builder import LOSSES
from .utils import weight_reduce_loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.

    Returns:
        torch.Tensor: The calculated loss
    """
    # Element-wise losses: the reduction is deferred to weight_reduce_loss
    # so that per-sample weights and avg_factor can be applied uniformly.
    loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')

    # Apply per-sample weights (cast to float so the multiply is well-typed)
    # and do the reduction.
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def _expand_onehot_labels_raw(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
#     print(torch.npu.synchronize(), "bin_labels 1", bin_labels.dtype)
    inds = ((labels >= 0) & (labels < label_channels))
#     print(torch.npu.synchronize(), "bin_labels 2")
    labels = labels.long()
    if inds.any() > 0:
#         print(labels.shape, labels.max(), labels.min())
#         print(bin_labels.shape, inds.shape, (labels*inds).shape)
        #NPU
        bin_labels[inds,labels*inds] = 1
        #bin_labels[inds, labels[inds]] = 1
#     print(torch.npu.synchronize(), "bin_labels 3")
    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)
#     print(torch.npu.synchronize(), "bin_labels",bin_labels.shape)
#     print(torch.npu.synchronize(), "bin_label_weights",bin_label_weights.shape)

    return bin_labels, bin_label_weights



def _expand_onehot_labels(labels, label_weights, label_channels):
    inds = (labels >= 0) & (labels < label_channels)
    labels = torch.clamp(labels, 0, label_channels-1)
    labels_one_hot = F.one_hot(labels, label_channels)

    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)
    return inds.unsqueeze(-1) * labels_one_hot, bin_label_weights




def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.

    Returns:
        torch.Tensor: The calculated loss
    """
    # Class-index labels (dim mismatch with pred) are expanded to one-hot
    # form; the weight is expanded alongside to keep shapes aligned.
    if pred.dim() != label.dim():
        label, weight = _expand_onehot_labels(label, weight, pred.size(-1))

    # Weighted element-wise losses.
    if weight is not None:
        weight = weight.float()

    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')

    # Do the reduction for the weighted loss.
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='none',
                       avg_factor=None,
                       class_weight=None,
                       reduce_weight=None):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask'
            corresponding object. This will be used to select the mask in the
            of the class which the object belongs to when the mask prediction
            if not class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        reduce_weight (torch.Tensor, optional): Element-wise weight applied
            before reduction. When given, the loss is a weighted sum
            normalized by the number of mask elements. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss
    """
    # TODO: handle these two reserved arguments (reduction / avg_factor are
    # currently ignored; the original upstream code asserted
    # `reduction == 'mean' and avg_factor is None`).
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    # Select, for each RoI, the mask prediction of its own class.
    pred_slice = pred[inds, label].squeeze(1)

    if reduce_weight is not None:
        loss = F.binary_cross_entropy_with_logits(
            pred_slice, target, weight=class_weight, reduction='none')

        # NOTE(review): hard-coded .npu() ties this branch to the NPU
        # device — confirm it is never reached on CPU/GPU.
        reduce_weight = reduce_weight.float().npu()

        # Weighted sum normalized by the per-mask element count (H * W).
        loss = weight_reduce_loss(
            loss, reduce_weight, reduction='none').sum() / (
                reduce_weight.size(1) * reduce_weight.size(2))

        return loss
    else:
        # [None] keeps a 1-element tensor so callers can concatenate losses.
        loss = F.binary_cross_entropy_with_logits(
            pred_slice, target, weight=class_weight, reduction='mean')[None]
        return loss


@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss supporting softmax, sigmoid and mask variants.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            of softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        # The sigmoid and mask variants are mutually exclusive.
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight

        # Select the element-wise criterion once, at construction time.
        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): Overrides ``self.reduction``.
                Options are "none", "mean" and "sum".
        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # Materialize the class weights on the prediction's device.
        class_weight = None
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(
                self.class_weight, device=cls_score.device)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return self.loss_weight * raw_loss
